import torch
import argparse
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as utils
import torch.optim as optim
from codebase import utils as ut
from codebase.models import nns
from torch import nn
from torch.nn import functional as F
from torch.nn import ReLU
from torch import Tensor
from torch import distributions as dist
from typing import List, Optional
from numpy.random import permutation, randint
from Conditioners import Conditioner, DAGConditioner
from numbers import Number
from typing import List, Tuple
# Default compute device for the whole module. "cuda:2" is kept as the
# preferred GPU, but hosts exposing fewer than three GPUs fall back to
# "cuda:0" (instead of failing later with an invalid device ordinal), and
# hosts without CUDA fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:2" if torch.cuda.device_count() > 2 else "cuda:0")
else:
    device = torch.device("cpu")

# The GaussianMLP implementation is copied from: https://github.com/siamakz/iVAE/blob/master/lib/models.py.
# The AffineNormalizerDAG implementation is adapted from: https://github.com/AWehenkel/Graphical-Normalizing-Flows/blob/master/models/Normalizers/AffineNormalizer.py.

def weights_init(m):
    """Re-initialise the weight of an ``nn.Linear`` module with Xavier-uniform values.

    Modules of any other type are left untouched, and the bias is never
    modified. Returns ``None``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight.data)


class MLP(nn.Module):
    """Fully connected network with a configurable activation per hidden layer.

    Args:
        input_dim: size of the input features.
        output_dim: size of the output features.
        hidden_dim: a single number (used for every hidden layer) or a list
            giving the width of each of the ``n_layers - 1`` hidden layers.
        n_layers: total number of linear layers.
        activation: a single name (used after every hidden layer) or a list
            of names, one per hidden layer. Supported names are ``'lrelu'``,
            ``'xtanh'``, ``'sigmoid'`` and ``'none'``.
        slope: negative slope for ``'lrelu'`` and linear coefficient for
            ``'xtanh'``.
        device: device the module is moved to after construction.

    Raises:
        ValueError: if ``hidden_dim`` or ``activation`` has an unsupported
            type, or if an activation name is unknown.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, n_layers, activation='none', slope=.1, device='cpu'):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.device = device
        if isinstance(hidden_dim, Number):
            self.hidden_dim = [hidden_dim] * (self.n_layers - 1)
        elif isinstance(hidden_dim, list):
            self.hidden_dim = hidden_dim
        else:
            raise ValueError('Wrong argument type for hidden_dim: {}'.format(hidden_dim))

        if isinstance(activation, str):
            self.activation = [activation] * (self.n_layers - 1)
        elif isinstance(activation, list):
            # Store the list as the activation spec; it must not overwrite
            # the hidden layer sizes.
            self.activation = activation
        else:
            raise ValueError('Wrong argument type for activation: {}'.format(activation))

        self._act_f = [self._build_activation(act, slope) for act in self.activation]

        if self.n_layers == 1:
            _fc_list = [nn.Linear(self.input_dim, self.output_dim)]
        else:
            _fc_list = [nn.Linear(self.input_dim, self.hidden_dim[0])]
            for i in range(1, self.n_layers - 1):
                _fc_list.append(nn.Linear(self.hidden_dim[i - 1], self.hidden_dim[i]))
            _fc_list.append(nn.Linear(self.hidden_dim[self.n_layers - 2], self.output_dim))
        self.fc = nn.ModuleList(_fc_list)
        self.to(self.device)

    def _build_activation(self, act, slope):
        """Return the callable for activation name ``act``; raise on unknown names."""
        if act == 'lrelu':
            return lambda x: F.leaky_relu(x, negative_slope=slope)
        if act == 'xtanh':
            return lambda x: self.xtanh(x, alpha=slope)
        if act == 'sigmoid':
            return F.sigmoid
        if act == 'none':
            return lambda x: x
        # Fail loudly: silently skipping an unknown name would leave fewer
        # activations than hidden layers.
        raise ValueError('Incorrect activation: {}'.format(act))

    @staticmethod
    def xtanh(x, alpha=.1):
        """tanh function plus an additional linear term"""
        return x.tanh() + alpha * x

    def forward(self, x):
        """Apply every linear layer in order; the last layer has no activation."""
        h = x
        last = self.n_layers - 1
        for c in range(self.n_layers):
            h = self.fc[c](h)
            if c != last:
                h = self._act_f[c](h)
        return h


class Dist:
    """Minimal distribution interface; every method is a no-op placeholder
    that returns ``None`` and is meant to be overridden by subclasses."""

    def __init__(self):
        return None

    def sample(self, *args):
        """Placeholder sampler: accepts any positional arguments."""
        return None

    def log_pdf(self, *args, **kwargs):
        """Placeholder log-density: accepts any arguments."""
        return None


class Normal(Dist):
    """Gaussian sampling and log-density helpers parameterised by mean and variance.

    Args:
        device: device on which the internal constants and the standard
            normal used for sampling are created.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        self.c = 2 * np.pi * torch.ones(1).to(self.device)
        self._dist = dist.normal.Normal(torch.zeros(1).to(self.device), torch.ones(1).to(self.device))
        self.name = 'gauss'

    def sample(self, mu, v):
        """Draw ``mu + eps * sqrt(v)`` with ``eps`` standard normal, shaped like ``mu``."""
        # The base distribution has batch shape (1,), so the draw carries one
        # extra trailing axis. Only that axis is removed: a bare squeeze()
        # would also drop genuine size-1 axes of mu (e.g. (B, 1) -> (B,)),
        # which then broadcasts against v to the wrong (B, B) shape.
        eps = self._dist.sample(mu.size()).squeeze(-1)
        return eps.mul(v.sqrt()).add(mu)

    def log_pdf(self, x, mu, v, reduce=True, param_shape=None):
        """compute the log-pdf of a normal distribution with diagonal covariance

        ``v`` is the variance. When ``param_shape`` is given, ``mu`` and ``v``
        are viewed to that shape first. With ``reduce=True`` the element-wise
        log-density is summed over the last axis.
        """
        if param_shape is not None:
            mu, v = mu.view(param_shape), v.view(param_shape)
        lpdf = -0.5 * (torch.log(self.c) + v.log() + (x - mu).pow(2).div(v))
        return lpdf.sum(dim=-1) if reduce else lpdf

    def log_pdf_full(self, x, mu, v):
        """
        compute the log-pdf of a normal distribution with full covariance
        v is a batch of "pseudo sqrt" of covariance matrices of shape (batch_size, d_latent, d_latent)
        mu is batch of means of shape (batch_size, d_latent)
        """
        batch_size, d = mu.size()
        cov = torch.einsum('bik,bjk->bij', v, v)  # compute batch cov from its "pseudo sqrt"
        assert cov.size() == (batch_size, d, d)
        inv_cov = torch.inverse(cov)  # works on batches
        c = d * torch.log(self.c)
        _, logabsdets = self._batch_slogdet(cov)
        xmu = x - mu
        return -0.5 * (c + logabsdets + torch.einsum('bi,bij,bj->b', [xmu, inv_cov, xmu]))

    def _batch_slogdet(self, cov_batch: torch.Tensor):
        """
        compute the sign and the log of the absolute value of the determinant
        for a batch of 2D matrices, returned as ``(signs, logabsdets)``.
        """
        # torch.slogdet accepts batched input, so no per-matrix Python loop
        # is needed; results keep the dtype and device of cov_batch.
        signs, logabsdets = torch.slogdet(cov_batch)
        return signs, logabsdets

class GaussianMLP(nn.Module):
    """Gaussian whose mean and log-variance are either MLPs or fixed constants.

    When ``fixed_mean`` (resp. ``fixed_var``) is ``None`` the mean (resp.
    log-variance) is produced by an ``MLP`` built from the remaining
    arguments; otherwise it is a constant one-element tensor on ``device``
    that ignores the input.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, n_layers, activation, slope, device, fixed_mean=None,
                 fixed_var=None):
        super().__init__()
        self.distribution = Normal(device=device)

        def make_mlp():
            # Both heads share the same architecture arguments.
            return MLP(input_dim, output_dim, hidden_dim, n_layers,
                       activation=activation, slope=slope, device=device)

        if fixed_mean is not None:
            self.mean = lambda x: fixed_mean * torch.ones(1).to(device)
        else:
            self.mean = make_mlp()

        if fixed_var is not None:
            self.log_var = lambda x: np.log(fixed_var) * torch.ones(1).to(device)
        else:
            self.log_var = make_mlp()

    def sample(self, *params):
        """Forward ``params`` to the underlying distribution's ``sample``."""
        return self.distribution.sample(*params)

    def log_pdf(self, x, *params, **kwargs):
        """Forward all arguments to the underlying distribution's ``log_pdf``."""
        return self.distribution.log_pdf(x, *params, **kwargs)

    def forward(self, *input):
        """Return ``(mean, variance)`` for the inputs.

        Several inputs are concatenated along dimension 1 first; the variance
        is the exponential of the log-variance head.
        """
        x = input[0] if len(input) <= 1 else torch.cat(input, dim=1)
        mean = self.mean(x)
        variance = self.log_var(x).exp()
        return mean, variance

class AffineNormalizerDAG(nn.Module):
    """Affine transform whose per-variable shift and log-scale come from a conditioner."""

    def __init__(self,conditioner:Conditioner):
        super(AffineNormalizerDAG, self).__init__()
        self.conditioner=conditioner

    def forward(self, z, context=None):
        """Fill ``x`` one column at a time and return ``(x, log_det)``.

        ``context`` is accepted but not used. The conditioner output is
        indexed as ``[:, :, 0]`` (shift) and ``[:, :, 1]`` (log-scale), both
        clamped in place to ``[-2, 2]``.
        """
        n_vars = z.size(1)
        out = torch.zeros_like(z).to(device)
        log_det = torch.zeros(z.size(0), 1).to(device)
        for col in range(n_vars):
            # The conditioner sees a copy of the columns filled so far.
            params = self.conditioner(out.clone())
            shift = params[:, :, 0].clamp_(-2., 2.)
            log_scale = params[:, :, 1].clamp_(-2., 2.)
            out[:, col] = shift[:, col] + z[:, col] * torch.exp(log_scale[:, col])
        # The log-determinant uses the log-scales from the final conditioner call.
        log_det += torch.sum(log_scale, axis=1, keepdim=True)
        return out, log_det
   
class AffineNormalizerDAG_inference(nn.Module):
    """Traversal variant of the affine DAG transform: one variable is pinned
    to a sweep of values while the others are recomputed from the conditioner."""

    def __init__(self,conditioner:Conditioner):
        super().__init__()
        self.conditioner=conditioner

    def forward(self, z, gap=3, n=10):
        """Return an ``(n * 4, 4)`` tensor stacking one ``n``-row sweep per variable.

        Variables 0 and 1 sweep ``[-gap, gap]``, variable 2 sweeps
        ``[-1.9, 0.5]`` and variable 3 sweeps ``[2.21, 0.35]``.
        """
        dim = 4
        inference_x = torch.zeros((n * dim, dim)).to(device)
        for idx in range(dim):
            # Pick the sweep end-points for the pinned variable.
            if idx in (0, 1):
                lo, hi = -gap, +gap
            elif idx == 2:
                lo, hi = -1.9, 0.5
            else:
                lo, hi = 2.21, 0.35

            x = torch.zeros_like(z).to(device).expand(n, dim).clone()
            x[:, idx] = torch.linspace(lo, hi, steps=n)

            for i in range(z.size(1)):
                if i == idx:
                    # The pinned column keeps its sweep values.
                    continue
                h = self.conditioner(x.clone())
                mu = h[:, :, 0].clamp_(-2., 2.)
                sigma = h[:, :, 1].clamp_(-2., 2.)
                x[:, i] = mu[:, i] + z[:, i] * torch.exp(sigma[:, i])

            first, last = n * idx, n * (idx + 1)
            with torch.no_grad():
                inference_x[first:last, :] = x
        return inference_x
    
class CausalFlow(nn.Module):
    """Apply the affine DAG transform to the factor slice of a latent vector.

    The latent is laid out as ``[dim1 | dim2 | dim3]``: columns
    ``dim1:dim1 + dim2`` go through the flow, the rest are copied unchanged.
    """

    def __init__(self,conditioner:Conditioner,n_layers_CauF,dim=8, dim1=2,dim2=4,dim3=2):
        super().__init__()
        self.dim1=dim1
        self.dim2=dim2
        self.dim3=dim3
        self.dim=dim
        self.n_layers_CauF=n_layers_CauF
        self.layers = nn.ModuleList()
        self.conditioner=conditioner
        self.net = AffineNormalizerDAG(self.conditioner)

    def forward(self, z: Tensor):
        """Return ``(z2, log_det)`` where ``z2`` is ``z`` with its factor slice transformed."""
        z1 = z.to(device)
        z2 = torch.zeros_like(z1).to(device)
        [B, _] = list(z1.size())
        log_det = torch.zeros(B, 1).to(device)

        # One slicing rule covers every (dim1, dim3) combination. Empty
        # slices are harmless no-ops, so dim1 == 0 or dim3 == 0 need no
        # special case, and dim1 is honoured even when dim3 == 0.
        start = self.dim1
        stop = self.dim1 + self.dim2
        z2[:, :start] = z1[:, :start]
        z2[:, start:stop], inc = self.net(z1[:, start:stop])
        z2[:, stop:self.dim] = z1[:, stop:self.dim]

        log_det = log_det + inc
        return z2, log_det

class CausalFlow_inference(nn.Module):
    """Traversal counterpart of ``CausalFlow``: expand one latent row into
    ``n * 4`` rows whose factor slice is swept by ``AffineNormalizerDAG_inference``."""

    def __init__(self,conditioner:Conditioner,n_layers_CauF,dim=8, dim1=2,dim2=4,dim3=2):
        super().__init__()
        self.dim1=dim1
        self.dim2=dim2
        self.dim3=dim3
        self.dim=dim
        self.n_layers_CauF=n_layers_CauF
        self.layers = nn.ModuleList()
        self.conditioner=conditioner
        self.net = AffineNormalizerDAG_inference(self.conditioner)

    def forward(self, z:Tensor,gap=3,n=10):
        """Return an ``(n * 4, latent_dim)`` tensor of traversed latents.

        ``z`` must be expandable to ``n * 4`` rows, i.e. hold a single row.
        """
        label_dim = 4
        z_dev = z.to(device)
        [_, latent_dim] = list(z_dev.size())
        rows = n * label_dim
        z2 = torch.zeros(rows, latent_dim).to(device)
        z1 = z_dev.expand(rows, latent_dim)

        # Same uniform layout as the training flow: copy the columns before
        # and after the factor slice, sweep the factor slice itself. The net
        # receives the device-resident latent so its tensors share a device.
        start = self.dim1
        stop = self.dim1 + self.dim2
        z2[:, :start] = z1[:, :start]
        z2[:, start:stop] = self.net(z_dev[:, start:stop], gap, n)
        z2[:, stop:self.dim] = z1[:, stop:self.dim]
        return z2
    
class Flow(nn.Module):
    """Container holding a single ``CausalFlow`` layer."""

    def __init__(
        self,
        conditioner: Conditioner,
        dim: int,
        n_layers_CauF: int, 
        dim1=2,
        dim2=6,
        dim3=2):    
        super(Flow, self).__init__()
        self.conditioner = conditioner
        self.dim, self.dim1, self.dim2, self.dim3 = dim, dim1, dim2, dim3
        self.n_layers_CauF = n_layers_CauF
        causal_layer = CausalFlow(self.conditioner, n_layers_CauF, dim, dim1, dim2, dim3)
        self.layers = nn.ModuleList([causal_layer])

    def forward(self, z):
        """Return ``(z, transformed_z, log_det)``: the untouched input, the
        output of the causal layer, and its log-determinant added to zeros."""
        accumulated = torch.zeros(z.shape[0], 1).to(device)
        transformed, layer_log_det = self.layers[0](z)
        accumulated = accumulated + layer_log_det
        return z, transformed, accumulated
    
class CauFVAE(nn.Module):
    '''CauF-VAE model

       Args:
            Causal Flows:
                DAGConditioner: conditioner class used in Causal Flows
                conditioner_args: keyword arguments for the conditioner, e.g.
                    {"in_size": args.dim2, "hidden": args.emb_net[:-1], "out_size": args.emb_net[-1],"l1":args.l1, "A_prior":A}
                n_layers_CauF: number of single-step flows in Causal Flows
            General settings:
                dim: latent dimension
                dim1: number of dimensions before factors' dimensions
                dim2: number of interested factors
                dim3: number of dimensions after factors' dimensions
                nn: name of the module under codebase/models/nns that defines
                    the Encoder and Decoder structure

    '''
    def __init__(self,
        DAGConditioner, 
        conditioner_args,
        dim=6,
        n_layers_CauF=1,
        dim1=0,
        dim2=4,
        dim3=2,
        nn='network'):
        super().__init__()
        self.conditioner = DAGConditioner(**conditioner_args)
        self.dim = dim
        self.dim1 = dim1
        self.dim2 = dim2
        self.dim3= dim3
        self.n_layers_CauF= n_layers_CauF
        self.u_dim=dim2
        # Small note: unfortunate name clash with torch.nn
        # nn here refers to the specific architecture file found in
        # codebase/models/nns/*.py
        nn = getattr(nns, nn)
        self.num_label=dim2
        self.enc = nn.ConvEncoder(self.dim,self.u_dim)
        self.dec = nn.ConvDecoder(self.dim)
        self.flow = Flow(self.conditioner,dim,n_layers_CauF,dim1,dim2,dim3)
        # Build the prior on the module-wide device rather than a hard-coded
        # 'cuda:2', so the model can also be constructed on CPU-only hosts.
        self.prior = GaussianMLP(self.dim2, self.dim2, hidden_dim=50, n_layers=3, activation='lrelu', slope=.1, device=device)
        # Per-label (offset, divisor) pairs used to normalise raw labels.
        self.scale = np.array([[1.5000,41.5000],[103.5000,43.5000],[7.5000, 4.5000],[10.5000,8.5000]])
        
    def getConditioners(self):
        """Return the conditioner wrapped in a one-element list."""
        return [self.conditioner]
                               
    def set_zero_grad(self):
        """Delegate to the conditioner's ``set_zero_grad``."""
        return self.conditioner.set_zero_grad()
    
    def encode(self, x, u=None):
        """Encode ``x`` into ``(mean, logvar)``; ``u`` is accepted but unused."""
        m, logv = self.enc.encode(x.to(device))
        return m, logv

    def transform(self, mean, logvar):
        """Reparameterise a latent sample and push it through the flow.

        Returns the flow output ``(z_before_flow, z_after_flow, log_det)``.
        """
        std = torch.exp(.5 * logvar).to(device)
        eps = torch.randn_like(std).to(device)
        z = eps.mul(std).add_(mean)
        return self.flow(z)

    def decode(self, z):
        """Decode a latent tensor with the decoder network."""
        x=self.dec.decode(z)
        return x

    def _normalize_label(self, label):
        """Return ``(label - scale[:, 0]) / scale[:, 1]`` column-wise.

        Computed in float64 and cast back to the dtype of ``label``, matching
        an element-by-element Python-float computation; the result lives on
        the module-wide device.
        """
        n_cols = label.size()[1]
        scale = torch.as_tensor(self.scale[:n_cols], dtype=torch.float64, device=label.device)
        normalized = (label.to(torch.float64) - scale[:, 0]) / scale[:, 1]
        return normalized.to(dtype=label.dtype).to(device)
    
    def traverse(self, x, gap=3, n=10):
        """Decode ``n`` images per factor while sweeping that latent coordinate.

        ``gap`` is accepted but unused. ``x`` must encode to a single latent
        row, because the latent is expanded to ``n`` rows. Returns a tensor of
        shape ``(n * num_factors, 3, 64, 64)``.
        """
        mean, logvar = self.encode(x)
        _, z, _ = self.transform(mean, logvar)
        dim = self.num_label if self.num_label is not None else self.dim
        sample = torch.zeros((n * dim, 3, 64, 64)).to(device)
        z=z.expand(n,self.dim)
        # Sweep end-points per latent index; every other index uses the default.
        sweep_ranges = {1: (0.01, 0.99), 2: (-1.9, 0.5), 3: (2.21, 0.35)}
        default_range = (0.4, 1.4)
        for idx in range(dim):
            start, end = sweep_ranges.get(idx, default_range)
            z_new = z.clone()
            z_new[:, idx] = torch.linspace(start, end, steps=n)
            with torch.no_grad():
                # reshape replaces the deprecated non-inplace Tensor.resize.
                sample[n * idx:(n * (idx + 1)), :, :, :] = self.decode(z_new).reshape(n,3,64,64)
        return sample


    def intervene(self, x, gap=3,n=10):
        """Decode the latents produced by ``CausalFlow_inference`` for input ``x``.

        Returns a tensor of shape ``(n * num_factors, 3, 64, 64)``.
        """
        mean, logvar = self.encode(x)
        # Only the pre-flow latent (first element) is needed here.
        z1, _, _ = self.transform(mean, logvar)
        dim = self.num_label if self.num_label is not None else self.dim
        with torch.no_grad():
            intervene1 = CausalFlow_inference(self.conditioner,self.n_layers_CauF,self.dim,self.dim1,self.dim2,self.dim3)
            # reshape replaces the deprecated non-inplace Tensor.resize.
            sample =  self.decode(intervene1(z1,gap,n)).reshape(n*dim,3,64,64)
        return sample
           
    def reconstruction_loss(self, x, x_hat):
        """Squared error summed over the last two axes, averaged over the rest."""
        x_hat1=x_hat.reshape(x.size())
        # Compare on the reconstruction's device so a CPU-resident x also works.
        target = x.to(x_hat1.device)
        squared_error = nn.MSELoss(reduction='none')(input=x_hat1, target=target)
        per_slice = squared_error.sum(-1).sum(-1)
        return torch.mean(per_slice).to(device)
        
    def latent_loss(self, label, mean, logvar, log_det):
        """KL term against the label-conditioned prior minus the flow log-determinant.

        The prior is standard normal except on columns ``dim1:dim1 + dim2``,
        where mean and variance come from ``self.prior`` applied to the
        normalised labels.
        """
        label_ = self._normalize_label(label)
        cp_m, cp_v = self.prior(label_)
        batch = label.size()[0]
        lo, hi = self.dim1, self.dim1 + self.dim2
        p_m=torch.zeros(batch,self.dim).to(device)
        p_m[:,lo:hi]=cp_m
        p_v=torch.ones(batch,self.dim).to(device)
        p_v[:,lo:hi]=cp_v
        var=torch.exp(logvar).to(device)
        kl=ut.kl_normal(mean.to(device), var, p_m, p_v)
        # NOTE(review): the shape of kl depends on ut.kl_normal, which is not
        # visible here; the subtraction relies on broadcasting against log_det.
        latent_loss=kl-log_det
        return latent_loss.mean()

    def sup_loss(self, z, label, sup_type, sup_prop, sup_coef):
        """Supervised loss between the factor slice of ``z`` and normalised labels.

        A random fraction ``sup_prop`` of the batch is drawn without
        replacement. ``sup_type == 'ce'`` uses BCE-with-logits, anything else
        uses MSE; each is averaged per sample, summed over the drawn samples
        and scaled by ``sup_coef``. Returns a one-element tensor, all zero
        when ``sup_coef`` is 0.
        """
        if sup_coef == 0:
            return torch.zeros([1], device=device)

        label1 = self._normalize_label(label)
        batch = label1.size()[0]
        available_label_index = np.random.choice(batch, int(batch * sup_prop), replace=False)
        z_sup_fake = z[available_label_index,self.dim1:self.dim1+self.dim2]
        label2 = label1[available_label_index,:]

        if sup_type == 'ce':
            criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
        else:
            criterion = nn.MSELoss(reduction='none')
        # Mean over factors per sample, then sum over the drawn samples.
        per_sample = criterion(z_sup_fake, label2).mean(dim=1)
        sup_loss = per_sample.sum().reshape(1)
        return sup_loss * sup_coef
   
    def loss(self, x, x_hat, label, scale, mean, logvar, log_det, z_sup, sup_type, sup_prop, sup_coef):
        """Return ``(reconstruction, latent, supervised, total)`` losses.

        Each term is evaluated once. This matters for the supervised term,
        which draws a random subset: evaluating it twice would report a value
        different from the one contained in the total. ``scale`` is accepted
        but unused.
        """
        rec = self.reconstruction_loss(x, x_hat)
        latent = self.latent_loss(label, mean, logvar, log_det)
        sup = self.sup_loss(z_sup, label, sup_type, sup_prop, sup_coef)
        loss = rec + latent + sup
        return rec, latent, sup, loss
    
    def forward(self, x, label,scale, sup_type, sup_prop, sup_coef):
        """Encode, flow, decode and return ``(rec, KL, suploss, loss, x_hat)``
        with ``x_hat`` reshaped to the shape of ``x``."""
        mean, logvar = self.encode(x, label)
        _, z, log_det = self.transform(mean, logvar)
        log_det=log_det.to(device)
        x_hat = self.decode(z)
        rec, KL,suploss,loss=self.loss(x, x_hat, label,scale, mean, logvar, log_det, z, sup_type,sup_prop, sup_coef)
        return rec, KL,suploss,loss, x_hat.reshape(x.size())
    
